from rknnlite.api import RKNNLite
from .base import Backend
import numpy as np
from geartrain.utils.import_utils import is_torch_available

if is_torch_available():
    import torch


class RKNN_Backend(Backend):
    """Inference backend running compiled .rknn models via RKNNLite on Rockchip NPUs."""

    def __init__(self, model_path=None, device="rk3588", **kwargs):
        """
        Args:
            model_path (str, optional): Path to the compiled .rknn model file.
            device (str): Target Rockchip SoC identifier (default "rk3588").
            **kwargs: Optional settings:
                data_format (str): Layout of incoming tensors — "nchw", "ncwh",
                    or "nhwc" (default "nhwc", which skips the transpose).
                fp16 (bool): Cast inputs to float16 before inference (default True).
        """
        # .get() instead of [] so a missing key no longer raises an opaque
        # KeyError; the "nhwc" default makes _forward's transpose a no-op.
        self.data_format = kwargs.get("data_format", "nhwc")
        self.model_path = model_path
        # Previously hard-coded to True; still defaults to True for compatibility.
        self.fp16 = kwargs.get("fp16", True)
        super().__init__(device, "rknn")

    def _load_model(self):
        """Load the model from ``self.model_path`` and initialise the RKNN runtime.

        Raises:
            RuntimeError: If the model fails to load or the runtime fails to
                initialise (non-zero return code from RKNNLite).
        """
        self.rknn_lite = RKNNLite()
        ret = self.rknn_lite.load_rknn(self.model_path)
        if ret != 0:
            # Raise instead of exit(): library code must not kill the host process.
            raise RuntimeError(
                f"Load RKNN model failed (ret={ret}): {self.model_path}"
            )

        # Initialise the runtime environment.
        # On RK356x/RK3588 with Debian OS there is no need to specify a target.
        ret = self.rknn_lite.init_runtime()
        if ret != 0:
            raise RuntimeError(f"Init runtime environment failed (ret={ret})")

    def _forward(self, model_inputs):
        """Run inference on one batch.

        Args:
            model_inputs: A numpy array, or a torch tensor (moved to CPU here).

        Returns:
            A float32 numpy array, or a list of float32 numpy arrays when the
            model produces multiple outputs.
        """
        if isinstance(model_inputs, np.ndarray):
            _inputs = model_inputs
        else:
            # Assumes a torch tensor (possibly on an accelerator) — TODO confirm
            # callers never pass other array-likes.
            _inputs = model_inputs.cpu().numpy()

        if self.fp16:
            _inputs = _inputs.astype(np.float16)

        # RKNNLite is fed NHWC below; transpose from the declared source layout.
        if self.data_format == "nchw":
            _inputs = _inputs.transpose([0, 2, 3, 1])
        elif self.data_format == "ncwh":
            _inputs = _inputs.transpose([0, 3, 2, 1])

        y = self.rknn_lite.inference([_inputs], data_format=["nhwc"])

        if isinstance(y, (list, tuple)):
            # Unwrap single-output models for caller convenience.
            return (
                self.from_numpy(y[0])
                if len(y) == 1
                else [self.from_numpy(x) for x in y]
            )
        return self.from_numpy(y)

    def from_numpy(self, x):
        """
        Cast an output array to float32.

        NOTE: despite the name, this returns a numpy array, not a torch.Tensor
        (the previous docstring was wrong).

        Args:
            x (np.ndarray): The array to convert.

        Returns:
            (np.ndarray): The array cast to float32.
        """
        return x.astype(np.float32)
